'''
conv bn merge
'''

from modellist.VGG16 import VGGNet16
import torch
from collections import OrderedDict
from torchvision.models import vgg16_bn
from torch import nn
# Build the reference vgg16_bn and load the fine-tuned checkpoint on CPU.
model = vgg16_bn(pretrained=True).to('cpu')

# Replace the stock 1000-class ImageNet classifier with a 10-class head;
# this must match the architecture the checkpoint was trained with,
# otherwise load_state_dict below fails.
model.classifier = nn.Sequential(
    nn.Linear(512 * 7 * 7, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 10),
).to('cpu')

# map_location='cpu' lets a checkpoint saved on a CUDA device load on a
# CPU-only machine (without it torch.load tries to restore tensors onto
# the original GPU and raises).
checkpoint = torch.load("weights/best.pth", map_location='cpu')
new_model_state_dict = {k: v.to('cpu') for k, v in checkpoint["model_state_dict"].items()}
model.load_state_dict(new_model_state_dict)

"""  Functions  """
def merge(params, name, layer, eps=1e-5):
    """Accumulate conv/BN parameters and fold BN into the preceding conv.

    Called once per state-dict entry, in state-dict order.  Conv weights
    and biases are stashed in module-level globals; once the BatchNorm
    layer's ``running_var`` arrives (the last BN statistic in a PyTorch
    state dict), the BN transform is folded into the stashed conv
    parameters:

        w' = gamma / sqrt(var + eps) * w
        b' = gamma / sqrt(var + eps) * (b - mean) + beta

    Args:
        params: tensor for the current state-dict entry.
        name:   state-dict key (e.g. ``features.0.weight``); only its
                suffix after the last ``.`` is used for BN bookkeeping.
        layer:  either ``'Convolution'`` or ``'BatchNorm'``.
        eps:    BN variance epsilon; defaults to nn.BatchNorm2d's 1e-5.

    Returns:
        ``(merged_weight, merged_bias)`` when a fold completes,
        otherwise ``(None, None)``.
    """
    # Module-level state carried between consecutive calls.
    global weights, bias
    global bn_param

    if layer == 'Convolution':
        if 'weight' in name:
            # New conv layer: stash its weight and assume a zero bias
            # until (if ever) an explicit bias entry arrives.
            weights = params.data
            bias = torch.zeros(weights.size()[0])
        elif 'bias' in name:
            bias = params.data
        # Reset BN accumulation for the upcoming BN layer.
        bn_param = {}

    elif layer == 'BatchNorm':
        # Collect BN stats keyed by suffix (weight/bias/running_mean/...).
        bn_param[name.split('.')[-1]] = params.data

        # running_var is the last BN statistic in a PyTorch state dict,
        # so its arrival means the fold can be computed.
        if 'running_var' in name:
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + eps)
            weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp * (bias - bn_param['running_mean']) + bn_param['bias']

            return weights, bias

    return None, None

print("start merging conv and bn")
new_weights = OrderedDict()
inner_product_flag = False

for name, params in new_model_state_dict.items():
    # BN bookkeeping counters ('num_batches_tracked') are 0-dim tensors:
    # without this skip they fall into the final 'else' below, flip
    # inner_product_flag and pollute new_weights, corrupting the merge.
    if 'num_batches_tracked' in name:
        continue
    if len(params.size()) == 4:
        # 4-D tensor -> conv weight: start accumulating a new conv layer.
        _, _ = merge(params, name, 'Convolution')
        prev_layer = name
    elif len(params.size()) == 1 and not inner_product_flag:
        if name == prev_layer.replace('weight', 'bias'):
            # A 1-D tensor belonging to the conv module itself is the conv
            # bias, not a BN statistic.  Routing it to 'BatchNorm' (as the
            # original did implicitly) stores it in bn_param['bias'] where
            # the real BN bias later overwrites it, silently dropping the
            # conv bias from the fold.
            _, _ = merge(params, name, 'Convolution')
        else:
            w, b = merge(params, name, 'BatchNorm')
            if w is not None:
                # Fold complete: store merged tensors under the conv's key names.
                new_weights[prev_layer] = w
                new_weights[prev_layer.replace('weight', 'bias')] = b
    else:
        # Fully-connected layer: copied through unchanged.  Once one is seen,
        # every later 1-D tensor is an FC bias, never a BN statistic.
        new_weights[name] = params
        inner_product_flag = True

print('Aligning weight names...')
# NOTE(review): 'model' is the vgg16_bn network, whose state dict still
# contains the BN keys, so the length assert below cannot hold against the
# merged (BN-free) weights.  The unused 'VGGNet16' import suggests the
# alignment target was meant to be a BN-free VGG16 -- confirm which network
# the merged weights are supposed to load into.
pytorch_net_key_list = list(model.state_dict().keys())
new_weights_key_list = list(new_weights.keys())
assert len(pytorch_net_key_list) == len(new_weights_key_list)
# Rename merged keys positionally onto the target network's key names.
for index in range(len(pytorch_net_key_list)):
    new_weights[pytorch_net_key_list[index]] = new_weights.pop(new_weights_key_list[index])

SAVE = True
# save new weights
if SAVE:
    torch.save(new_weights, "./weights/merge_best.pth")

